{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebook configuration: logging plus the image folder constants.\n",
    "import logging\n",
    "\n",
    "# Without a configured handler the log.info() calls made later in this notebook\n",
    "# are silently dropped (the fallback handler only emits WARNING and above).\n",
    "# basicConfig() is a no-op when the root logger already has handlers.\n",
    "logging.basicConfig(level=logging.INFO)\n",
    "log = logging.getLogger(__name__)\n",
    "\n",
    "# The train/validate/test folders are derived from SCRAPER_DIR so the base\n",
    "# directory is spelled out only once; the resulting strings are unchanged.\n",
    "SCRAPER_DIR = 'pics'\n",
    "TRAIN_FOLDER = f'{SCRAPER_DIR}/train_cards'\n",
    "VALIDATE_FOLDER = f'{SCRAPER_DIR}/validate_cards'\n",
    "TEST_FOLDER = f'{SCRAPER_DIR}/test_cards'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put the parent of the current working directory at the front of the import\n",
    "# path (presumably so the `poker` package resolves from this notebook's folder).\n",
    "import os\n",
    "import sys\n",
    "\n",
    "parent_dir = os.path.abspath(os.pardir)\n",
    "sys.path.insert(0, parent_dir)\n",
    "\n",
    "from poker.scraper.table_scraper_nn import CardNeuralNetwork"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 52/52 [00:09<00:00,  5.21it/s]\n",
      "100%|██████████| 52/52 [00:09<00:00,  5.36it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 25429 images belonging to 52 classes.\n",
      "Found 25362 images belonging to 52 classes.\n",
      "Metal device set to: Apple M2 Pro\n",
      "\n",
      "systemMemory: 32.00 GB\n",
      "maxCacheSize: 10.67 GB\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-06-20 21:44:13.159095: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.\n",
      "2023-06-20 21:44:13.159735: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " conv2d (Conv2D)             (None, 50, 15, 64)        1792      \n",
      "                                                                 \n",
      " conv2d_1 (Conv2D)           (None, 50, 15, 64)        36928     \n",
      "                                                                 \n",
      " max_pooling2d (MaxPooling2D  (None, 25, 7, 64)        0         \n",
      " )                                                               \n",
      "                                                                 \n",
      " dropout (Dropout)           (None, 25, 7, 64)         0         \n",
      "                                                                 \n",
      " conv2d_2 (Conv2D)           (None, 25, 7, 128)        73856     \n",
      "                                                                 \n",
      " conv2d_3 (Conv2D)           (None, 25, 7, 128)        147584    \n",
      "                                                                 \n",
      " max_pooling2d_1 (MaxPooling  (None, 12, 3, 128)       0         \n",
      " 2D)                                                             \n",
      "                                                                 \n",
      " dropout_1 (Dropout)         (None, 12, 3, 128)        0         \n",
      "                                                                 \n",
      " conv2d_4 (Conv2D)           (None, 12, 3, 256)        295168    \n",
      "                                                                 \n",
      " conv2d_5 (Conv2D)           (None, 12, 3, 256)        590080    \n",
      "                                                                 \n",
      " max_pooling2d_2 (MaxPooling  (None, 6, 1, 256)        0         \n",
      " 2D)                                                             \n",
      "                                                                 \n",
      " dropout_2 (Dropout)         (None, 6, 1, 256)         0         \n",
      "                                                                 \n",
      " flatten (Flatten)           (None, 1536)              0         \n",
      "                                                                 \n",
      " dense (Dense)               (None, 2048)              3147776   \n",
      "                                                                 \n",
      " dropout_3 (Dropout)         (None, 2048)              0         \n",
      "                                                                 \n",
      " dense_1 (Dense)             (None, 1024)              2098176   \n",
      "                                                                 \n",
      " dropout_4 (Dropout)         (None, 1024)              0         \n",
      "                                                                 \n",
      " dense_2 (Dense)             (None, 52)                53300     \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 6,444,660\n",
      "Trainable params: 6,444,660\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "Epoch 1/50\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-06-20 21:44:13.541140: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz\n",
      "2023-06-20 21:44:13.875994: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "199/199 [==============================] - ETA: 0s - loss: 3.1090 - accuracy: 0.1287"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-06-20 21:44:24.432235: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "199/199 [==============================] - 20s 96ms/step - loss: 3.1090 - accuracy: 0.1287 - val_loss: 2.7084 - val_accuracy: 0.2593\n",
      "Epoch 2/50\n",
      "199/199 [==============================] - 18s 91ms/step - loss: 1.4316 - accuracy: 0.5396 - val_loss: 1.3452 - val_accuracy: 0.6046\n",
      "Epoch 3/50\n",
      "199/199 [==============================] - 18s 90ms/step - loss: 0.4745 - accuracy: 0.8413 - val_loss: 0.6106 - val_accuracy: 0.8295\n",
      "Epoch 4/50\n",
      "199/199 [==============================] - 19s 95ms/step - loss: 0.2253 - accuracy: 0.9250 - val_loss: 0.5851 - val_accuracy: 0.8307\n",
      "Epoch 5/50\n",
      "199/199 [==============================] - 19s 96ms/step - loss: 0.1481 - accuracy: 0.9518 - val_loss: 0.2995 - val_accuracy: 0.9184\n",
      "Epoch 6/50\n",
      "199/199 [==============================] - 19s 93ms/step - loss: 0.1069 - accuracy: 0.9669 - val_loss: 0.1781 - val_accuracy: 0.9496\n",
      "Epoch 7/50\n",
      "199/199 [==============================] - 19s 93ms/step - loss: 0.0908 - accuracy: 0.9727 - val_loss: 0.1818 - val_accuracy: 0.9503\n",
      "Epoch 8/50\n",
      "199/199 [==============================] - 20s 102ms/step - loss: 0.0890 - accuracy: 0.9714 - val_loss: 0.2145 - val_accuracy: 0.9413\n",
      "Epoch 9/50\n",
      "199/199 [==============================] - 20s 102ms/step - loss: 0.0697 - accuracy: 0.9773 - val_loss: 0.1006 - val_accuracy: 0.9735\n",
      "Epoch 10/50\n",
      "199/199 [==============================] - 19s 96ms/step - loss: 0.0678 - accuracy: 0.9788 - val_loss: 0.1634 - val_accuracy: 0.9606\n",
      "Epoch 11/50\n",
      "199/199 [==============================] - 19s 95ms/step - loss: 0.0537 - accuracy: 0.9838 - val_loss: 0.3071 - val_accuracy: 0.9058\n",
      "Epoch 12/50\n",
      "199/199 [==============================] - 19s 97ms/step - loss: 0.0577 - accuracy: 0.9831 - val_loss: 0.2502 - val_accuracy: 0.9339\n",
      "Epoch 13/50\n",
      "199/199 [==============================] - 19s 94ms/step - loss: 0.0601 - accuracy: 0.9815 - val_loss: 0.0625 - val_accuracy: 0.9843\n",
      "Epoch 14/50\n",
      "199/199 [==============================] - 19s 97ms/step - loss: 0.0543 - accuracy: 0.9839 - val_loss: 0.1312 - val_accuracy: 0.9664\n",
      "Epoch 15/50\n",
      "199/199 [==============================] - 19s 97ms/step - loss: 0.0491 - accuracy: 0.9851 - val_loss: 0.0999 - val_accuracy: 0.9772\n",
      "Epoch 16/50\n",
      "199/199 [==============================] - 19s 96ms/step - loss: 0.0424 - accuracy: 0.9873 - val_loss: 0.0492 - val_accuracy: 0.9879\n",
      "Epoch 17/50\n",
      "199/199 [==============================] - 19s 94ms/step - loss: 0.0530 - accuracy: 0.9856 - val_loss: 0.1115 - val_accuracy: 0.9742\n",
      "Epoch 18/50\n",
      "199/199 [==============================] - 18s 92ms/step - loss: 0.0593 - accuracy: 0.9840 - val_loss: 0.3231 - val_accuracy: 0.9159\n",
      "Epoch 19/50\n",
      "199/199 [==============================] - 19s 97ms/step - loss: 0.0402 - accuracy: 0.9887 - val_loss: 0.0673 - val_accuracy: 0.9830\n",
      "Epoch 20/50\n",
      "199/199 [==============================] - 19s 93ms/step - loss: 0.0575 - accuracy: 0.9846 - val_loss: 0.0373 - val_accuracy: 0.9889\n",
      "Epoch 21/50\n",
      "199/199 [==============================] - 19s 97ms/step - loss: 0.0519 - accuracy: 0.9855 - val_loss: 0.1266 - val_accuracy: 0.9680\n",
      "Epoch 22/50\n",
      "199/199 [==============================] - 19s 96ms/step - loss: 0.0414 - accuracy: 0.9891 - val_loss: 0.0780 - val_accuracy: 0.9779\n",
      "Epoch 23/50\n",
      "199/199 [==============================] - 19s 95ms/step - loss: 0.0438 - accuracy: 0.9886 - val_loss: 0.1230 - val_accuracy: 0.9674\n",
      "Epoch 24/50\n",
      "199/199 [==============================] - 19s 96ms/step - loss: 0.0364 - accuracy: 0.9895 - val_loss: 0.0721 - val_accuracy: 0.9787\n",
      "Epoch 25/50\n",
      "199/199 [==============================] - 19s 95ms/step - loss: 0.0487 - accuracy: 0.9874 - val_loss: 0.1245 - val_accuracy: 0.9708\n",
      "Epoch 25: early stopping\n",
      "52/52 [==============================] - 2s 43ms/step - loss: 0.1217 - accuracy: 0.9701\n",
      "Validation loss: 0.12174148112535477\n",
      "Validation accuracy: 0.9701021909713745\n"
     ]
    }
   ],
   "source": [
    "# Generate augmented card images for one table, then train the network on them.\n",
    "table_name = \"test\"\n",
    "log.info(f\"Start training for {table_name}\")\n",
    "\n",
    "# NOTE: `n` is also read by the Hugging Face upload cell further down (n.model).\n",
    "n = CardNeuralNetwork()\n",
    "log.info(f\"Creating augmented images in {TRAIN_FOLDER}\")\n",
    "n.create_augmented_images(table_name)\n",
    "n.train_neural_network()\n",
    "\n",
    "# Persistence calls are left commented out; uncomment to store the trained model.\n",
    "# n.save_model_to_disk()\n",
    "# n.save_model_to_db(table_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op, _jit_compiled_convolution_op, _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 5 of 6). These functions will not be directly callable after loading.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /var/folders/54/bgjlgvk9363ghkydvgjpk5p40000gn/T/tmpi6ml6284/dickreuter/poker-card-classification/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /var/folders/54/bgjlgvk9363ghkydvgjpk5p40000gn/T/tmpi6ml6284/dickreuter/poker-card-classification/assets\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "937235229cfd4ef6aa9046f08f13aa91",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Upload 3 LFS files:   0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7878cb49a5834bdf99b97d7a10e33217",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "keras_metadata.pb:   0%|          | 0.00/32.8k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a5f2faa3d43a4ad2be81bfe94913b5f1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "saved_model.pb:   0%|          | 0.00/254k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9f0dcb715c32407b8687bc225b8b72f5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "variables.data-00000-of-00001:   0%|          | 0.00/25.8M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "'https://huggingface.co/dickreuter/poker-card-classification/tree/main/'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Upload the trained Keras model held by `n` to the Hugging Face Hub.\n",
    "import os\n",
    "\n",
    "from huggingface_hub import push_to_hub_keras\n",
    "\n",
    "# Never paste an access token into the notebook: read it from the environment.\n",
    "# When HF_TOKEN is unset or empty, None is passed so that huggingface_hub can\n",
    "# fall back to the locally cached login (e.g. from `huggingface-cli login`).\n",
    "token = os.environ.get(\"HF_TOKEN\") or None\n",
    "push_to_hub_keras(\n",
    "    n.model,\n",
    "    \"dickreuter/poker-card-classification\",\n",
    "    token=token,\n",
    "    tags=[\"poker-card-classification\", \"pokerbot\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "poker310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
